function [Lkhood,W,Mean,Var,lam,PK_Y,X]=gm_lambda2(mrFile,K,lam0,W0,Mean0,Var0,THRESHOLD)
%[Lkhood,W,Mean,Var,lam,PK_Y,X]=gm_lambda2(mrFile,K,lam0,W0,Mean0,Var0,THRESHOLD)
% Fit a K-component Gaussian mixture to the intensity histogram of the
% brain voxels (intensity > 0) of an image. Each component k models its
% intensities after a Box-Cox transform (see lam_trans) with its own
% exponent lam(k). Weights, means, variances and lambdas are estimated by
% EM on the histogram of the integer intensities 1..max intensity.
%
% Input:
%  mrFile: (string) file name of a *.img (read with readimg)
%  K: (scalar) number of clusters.
%  lam0: (K*1 column vector) initial values of lambda
%  W0: (K*1 column vector) initial values of cluster weight
%       When this argument is set as an empty vector, default values are used.
%  Mean0: (K*1 column vector) initial values of cluster mean
%       When this argument is set as an empty vector, default values are used.
%  Var0: (K*1 column vector) initial values of cluster variance
%       When this argument is set as an empty vector, default values are used.
%  THRESHOLD: (K*2 matrix) [csf_min csf_max; gray_min gray_max; white_min white_max]
%             Only the second column (upper bound of each cluster weight) is
%             used; the first column may hold any value.
% Output:
%  Lkhood: (1*MAXITER vector) log likelihood of the histogram at each EM iteration
%  W: (K*1 column vector) estimated cluster weights
%  Mean: (K*1 column vector) estimated cluster means (in the transformed domain)
%  Var: (K*1 column vector) estimated cluster variances (in the transformed domain)
%  lam: (K*1 column vector) estimated lambdas (clipped to be >= 0)
%  PK_Y: (nrows*ncolumns*nslices*K 4-D matrix) belonging probability p(k|y) for each voxel.
%  X: (nrows*ncolumns*nslices matrix) segmentation result (0 outside the brain)
%
% Note:
%  This program assumes that intensity values are all integers.
%  After each weight update, clusters whose weight exceeds its upper bound
%  are clamped to that bound and the remaining clusters are rescaled so the
%  weights still sum to one (for any K). If every cluster exceeds its bound,
%  the weights are left unchanged.
%
%Example:
%[Lk W M V Lam PK_Y]=gm_lambda2('mr.img',3,[1;2;2],[0.05; 0.5; 0.45],[],[],[0 0.1;0 0.65;0 0.5]);
%

MAXITER = 1000;%maximum number of EM iterations

mr = readimg(mrFile);
%====================== set up data =====================================
ImgDim = [size(mr,1) size(mr,2) size(mr,3)]; %get image dimensions
BrainIdx = find(mr>0); %assumes that intensities of brain voxels are greater than zero 
N = length(BrainIdx);   %N: #brain voxels
y = reshape(mr(BrainIdx),1,N); %y contains brain voxels only

%---------------------- setup I0(intensity table) and H(histogram)-----
MaxI = max(y);  %find max intensity
I = [1:MaxI];%I=[1 2 3 ... MaxI];
H = hist(y,I);%H=[#1 #2 #3 ... #MaxI];
logI = log(I);  %store log(I) into a variable so that log(I) is calculated only once

%---------------------- initialize I ----------------------------------
Ilam = zeros(K,MaxI);  %intensities after lambda transform
for k = 1:K, 
    Ilam(k,:) = lam_trans(I,lam0(k));
end

%============= set default value for input parameters ========================
if isempty(W0), 
    W0 = ones(K,1)/K; 
end

if (isempty(Mean0) | isempty(Var0)),
 % Split the sorted intensities into K consecutive sections sized by W0 and
 % use each section's transformed mean/variance as the starting value.
 B = sort(y,2);
 S = split_mat(W0,B); %get end points of each section
    
 if(isempty(Mean0)),  %calculate Mean0
   Mean0 = zeros(K,1);
   for k=1:K,
         Mean0(k) = mean(lam_trans(B(S(1,k):S(2,k)),lam0(k)));  
   end
 end
 
 if isempty(Var0),   %calculate Var0
   Var0 = zeros(K,1);
   for k=1:K,
         Var0(k) = var(lam_trans(B(S(1,k):S(2,k)),lam0(k)));   
   end
 end
    
 clear B S;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%% end of set default value %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
W = W0;
Mean = Mean0;
Var = Var0;
lam = lam0;

Lkhood = zeros(1,MAXITER);

PK_I = zeros(K,MaxI);    %p(k|i)
PI   = zeros(1,MaxI);    %p(i)
PI_K = zeros(K,MaxI);    %p(i|k)


J = zeros(K,MaxI);   %J(k,i), Jacobian


%======================= begin of EM loop ===============================
for iter = 1:MAXITER,
 %calculate Jacobian and I^(lam)
  for k=1:K
    J(k,:) = I.^(lam(k)-1);   % d(I_lam)/dI for the Box-Cox transform
    Ilam(k,:) = lam_trans(I,lam(k));
  end

  %--------------- E step: Calculate PK_I --------------------
  PI_K = normpdf(Ilam,repmat(Mean,1,MaxI),repmat(sqrt(Var),1,MaxI)).*J; %p(i|k)=g(i_lambda|k)*J(i,k)
  PIK = PI_K.*repmat(W,1,MaxI);    %p(i,k)=p(i|k)*p(k)
  PI = sum(PIK,1);                  %p(i)=sum_k(p(i,k))
  if any(PI<=0),
    fprintf('\np(i)<=0!!');
    PI(find(PI<=0)) = realmin;     % avoid log(0) and division by zero below
  end 
  PK_I = PIK./repmat(PI,K,1);  %p(k|i)=p(i,k)/p(i)

  %calculate log likelihood (histogram-weighted, so each voxel counts once)
  Lkhood(1,iter) = sum(log(PI).*H);
  fprintf('iter=%d, Li = %f, ',iter,Lkhood(iter));
 
  EARLY_BREAK = 0;   % set to 1 to stop once the likelihood stops increasing
  if EARLY_BREAK==1,
    if (iter>1)&((Lkhood(iter)-Lkhood(iter-1))<0.000001), 
        break; 
    end
  end
  %----------------- M step: ------------------------
  %---------Update W---------
  fprintf('W: ');
  for k=1:K,
    W(k) = sum(PK_I(k,:).*H)/N;
    fprintf('%g ,',W(k));
  end
  fprintf('sumW: %g ',sum(W));
  
  hit = find(W>THRESHOLD(:,2));  %check if any cluster hits the upper bound
  nohit = find(~(W>THRESHOLD(:,2)));
  
  % Clamp every cluster above its bound and give the remaining mass to the
  % other clusters in proportion to their current weights, so sum(W)==1.
  % For K=3 this is exactly the one-hit / two-hit rule; it also holds for
  % other K. Nothing is done when no cluster (or every cluster) hits.
  if ~isempty(hit) && ~isempty(nohit)
      W(hit) = THRESHOLD(hit,2);
      restW = 1 - sum(W(hit));           % mass left for the un-clamped clusters
      if length(nohit)==1
          W(nohit) = restW;              %adjusted for the rest cluster
      else
          sumNohit = sum(W(nohit));
          if sumNohit > 0
              W(nohit) = restW*W(nohit)./sumNohit; %adjusted for the other clusters
          else
              W(nohit) = restW/length(nohit);      % avoid 0/0 when they carry no mass
          end
      end
  end
 if(~isempty(hit)),
    fprintf('normalized W:');
     for k=1:K,
         fprintf('%g ,',W(k));
     end
 end

  if any(W<=0),
    fprintf('W(k)=0!!!'); 
    W(find(W<=0)) = 0.000001;   % keep every cluster alive, then renormalize
    W = W./sum(W);
  end
  
  %---------Update Mean---------
  for k=1:K,
      Mean(k) = sum(PK_I(k,:).*Ilam(k,:).*H)/sum(PK_I(k,:).*H);
  end
  %---------Update Var---------
  for k=1:K,
      Var(k) = sum(H.*PK_I(k,:).*((Ilam(k,:)-Mean(k)).^2))/sum(PK_I(k,:).*H); 
  end
  if any(Var <= 0.00001),
    fprintf('Attention!! Var is too small (<0.00001) for some k.'); 
    Var(find(Var<=0.00001)) = 0.00008;   % floor to keep normpdf well defined
  end
  %-------Update Lambda---------
  % Each lam(k) is the root of g(), searched by fzero starting from 3;
  % negative roots are clipped to 0 (lam_trans only supports lam >= 0).
  fprintf('lam: ');
  for k=1:K,
    sigma2 = Var(k);
    mu =Mean(k);
    lam(k) = fzero(@(lam) g(lam,mu,sigma2,I,logI,H,PK_I(k,:)),3);
    if(lam(k)<0) lam(k) = 0; end
    fprintf('%g,',lam(k));
  end
  fprintf('\n');
  %--
end%for iter = 1:MAXITER 
%======================= end of EM loop ==============================

%-- convert PK_I to probability image
if nargout>=6,
  PK_Y = zeros(K,prod(ImgDim)); %whole slice
  PK_Y(:,BrainIdx) = PK_I(:,y);   % look up p(k|i) by each voxel's integer intensity
  PK_Y = reshape(PK_Y',[ImgDim K]);
end

if nargout>=7,
 X = zeros(ImgDim);
 [temp,X]  =max(PK_Y,[],4);   % hard label = most probable cluster
 X(find(mr==0))=0;
end 

return%function

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [SEP]=split_mat(W,Y)
%SEP = split_mat(W,Y)
% Split the index range 1..length(Y) into K=length(W) consecutive sections
% whose lengths are proportional to the weights W.
%
% Input:
%  W: (K*1 vector) section weights; expected to be non-negative and sum to 1
%  Y: (1*N vector) only its length N is used
% Output:
%  SEP: section end points, 2xK matrix; section k is Y(SEP(1,k):SEP(2,k)).
%       Section k ends at fix(N*sum(W(1:k))), capped at N. When the weights
%       sum to 1 up to round-off (e.g. ones(10,1)/10 can sum to slightly less
%       than 1), the last section always ends at N so no element is dropped.
%       A section is empty (SEP(1,k) > SEP(2,k)) when W(k) is very small.
N = length(Y);
K = length(W);
if K == 0
  SEP = zeros(2,0);
  return
end
cumW = cumsum(W(:))';          % cumulative weight at the end of each section
q = min(N, fix(N*cumW));       % last index of each section, never past N
if abs(cumW(end)-1) <= K*eps   % weights sum to 1 up to round-off error
  q(end) = N;
end
p = [1, q(1:end-1)+1];         % each section starts right after the previous one
SEP = [p; q];
return

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [Y_lambda]=lam_trans(Y,lambda)
%[Y_lambda]=lam_trans(Y,lambda)
% Element-wise Box-Cox power transform of Y.
%
%   lambda > 0  : Y_lambda = (Y.^lambda - 1)/lambda
%   lambda == 0 : Y_lambda = log(Y)   (the lambda -> 0 limit)
%   otherwise   : two diagnostic lines are printed and an all-zero array of
%                 size(Y) is returned (negative or NaN lambda is unsupported).

if lambda > 0
  Y_lambda = (Y.^lambda - 1)/lambda;
  return
end

if lambda == 0
  Y_lambda = log(Y);
  return
end

% Unsupported lambda: report it and fall back to zeros.
Y_lambda = zeros(size(Y));
fprintf('lam_trans(): lambda = %g\n',lambda);
fprintf('lam_trans(): lambda must be in (0,inf)\n'); 
return
%====================================================================================
%------------------------------------------------
function f = g(lam,mu,sigma2,I,logI,H,pk_i)
%f = g(lam,mu,sigma2,I,logI,H,pk_i)
% Objective whose root (in lam) is the lambda update of one cluster.
% With z = lam_trans(I,lam) and weights w = H.*pk_i it returns
%   sum(w.*(z-mu).*lam.*dz/dlam)/(lam*sigma2) - sum(w.*log(I)),
% where lam*dz/dlam = log(I).*z*lam + log(I) - z. This is the negative
% derivative, w.r.t. lam, of the weighted Gaussian log-likelihood of z plus
% the log-Jacobian (lam-1)*log(I).
%
% Input:
%  lam: scalar, candidate lambda
%  mu, sigma2: scalars, cluster mean and variance in the transformed domain
%  I, logI: vectors, intensity table and its log
%  H: vector, intensity histogram
%  pk_i: vector, p(k|i) for this cluster
z = lam_trans(I,lam);
lamDz = logI.*z*lam + logI - z;        % lam * d(z)/d(lam)
gaussTerm = (z-mu).*lamDz;
gaussTerm = gaussTerm.*H.*pk_i;
jacobTerm = logI.*pk_i.*H;             % derivative of the log-Jacobian part
f = sum(gaussTerm)/(lam*sigma2) - sum(jacobTerm);

return;
%------------------------------------------------
